import numpy as np
import torch
import abc
import math

import acs.utils as utils
import acs.methods as opt


class CoresetConstruction(metaclass=abc.ABCMeta):
    def __init__(self, acq, data, posterior, **kwargs):
        """
        Common setup shared by the active learning batch constructors.
        Splits the dataset into its labeled and unlabeled parts, computes the posterior
        on the labeled part and initialises all acquisition scores to zero.
        :param acq: (function) Acquisition function.
        :param data: (ActiveLearningDataset) Dataset; must expose X, y and index['train'] / index['unlabeled'].
        :param posterior: (function) Called as posterior(X_train, y_train, **kwargs); must return (mean, covariance).
        :param kwargs: (dict) Additional arguments. 'save_dir' is consumed here, the rest is forwarded to posterior.
        """
        self.acq, self.posterior, self.kwargs = acq, posterior, kwargs

        # 'save_dir' is removed from kwargs so that it is not forwarded to the posterior call below.
        self.save_dir = self.kwargs.pop('save_dir', None)

        labeled = data.index['train']
        pool = data.index['unlabeled']
        self.X_train = data.X[labeled]
        self.y_train = data.y[labeled]
        self.X_unlabeled = data.X[pool]
        self.y_unlabeled = data.y[pool]

        posterior_stats = self.posterior(self.X_train, self.y_train, **self.kwargs)
        self.theta_mean, self.theta_cov = posterior_stats

        # One score per unlabeled point; subclasses overwrite this with their own scores.
        self.scores = np.zeros(len(self.X_unlabeled))

    def build(self, M=1, **kwargs):
        """
        Runs _init_build once and then _step M times, starting from an all-zero weight vector.
        :param M: (int) Batch size.
        :param kwargs: (dict) Additional arguments, forwarded to both _init_build and _step.
        :return: (numpy array of ints) Row indices (into the unlabeled set) whose weight is non-zero.
        """
        self._init_build(M, **kwargs)

        pool_size = len(self.X_unlabeled)
        weights = np.zeros((pool_size, 1))
        for step in range(M):
            weights = self._step(step, weights, **kwargs)

        selected_rows = weights.nonzero()[0]
        return selected_rows

    @abc.abstractmethod
    def _init_build(self, M, **kwargs):
        """
        Hook for computations that have to happen once before the batch is built.
        :param M: (int) Batch size.
        :param kwargs: (dict) Additional arguments.
        """

    @abc.abstractmethod
    def _step(self, m, w, **kwargs):
        """
        Adds the m-th element to the AL batch. Non-greedy batch methods implement this as well,
        since adding points one at a time facilitates plotting the selection over time.
        :param m: (int) Batch iteration.
        :param w: (numpy array) Current weight vector.
        :param kwargs: (dict) Additional arguments.
        :return: (numpy array) Updated weight vector (this base implementation returns None).
        """


class Random(CoresetConstruction):
    def __init__(self, acq, data, posterior, **kwargs):
        """
        Constructs a batch of points using random (uniform) sampling.
        :param acq: (function) Acquisition function.
        :param data: (ActiveLearningDataset) Dataset.
        :param posterior: (function) Function to compute posterior mean and covariance.
        :param kwargs: (dict) Additional arguments.
        """
        super().__init__(acq, data, posterior, **kwargs)
        # Uniform scores: every unlabeled point is equally likely to be selected.
        self.scores = np.ones(len(self.X_unlabeled),)

    def _init_build(self, M=1, seed=None, **kwargs):
        """
        Randomly selects M distinct unlabeled data points and marks them in self.counts.
        :param M: (int) Batch size. Must not exceed the number of unlabeled points
            (np.random.choice raises ValueError otherwise, because sampling is without replacement).
        :param seed: (int) Numpy random seed. If given, the global numpy RNG is re-seeded.
        :param kwargs: (dict) Additional arguments (ignored). Accepted because build() forwards
            all of its keyword arguments to this method.
        """
        if seed is not None:
            np.random.seed(seed)

        idx = np.random.choice(len(self.scores), M, replace=False)
        self.counts = np.zeros_like(self.scores)
        self.counts[idx] = 1.  # assign selected data points a count of 1

    def _step(self, m, w, **kwargs):
        """
        Adds the m-th selected data point to the batch.
        :param m: (int) Batch iteration.
        :param w: (numpy array) Current weight vector (modified in place).
        :param kwargs: (dict) Additional arguments.
        :return: (numpy array) Weight vector after adding m-th data point to the batch.
        """
        # Only the first `count_nonzero` entries of the descending sort were actually selected;
        # a strict comparison keeps m == count from picking an unselected (zero-count) point.
        if m < np.count_nonzero(self.counts):
            w[np.argsort(-self.counts)[m]] = 1.

        return w


class ImportanceSampling(CoresetConstruction):
    def __init__(self, acq, data, posterior, **kwargs):
        """
        Constructs a batch of points using importance sampling.
        :param acq: (function) Acquisition function.
        :param data: (ActiveLearningDataset) Dataset.
        :param posterior: (function) Function to compute posterior mean and covariance.
        :param kwargs: (dict) Additional arguments.
        """
        super().__init__(acq, data, posterior, **kwargs)
        # One acquisition score per unlabeled point, used as unnormalised sampling weights.
        self.scores = self.acq(self.theta_mean, self.theta_cov, self.X_unlabeled, **self.kwargs)

    def _init_build(self, M=1, seed=None, **kwargs):
        """
        Samples counts of unlabeled data points according to acquisition function scores using importance sampling.
        :param M: (int) Batch size (number of multinomial draws).
        :param seed: (int) Numpy random seed. If given, the global numpy RNG is re-seeded.
        :param kwargs: (dict) Additional arguments (ignored). Accepted because build() forwards
            all of its keyword arguments to this method.
        :raises ValueError: If the scores do not sum to a positive, finite value.
        """
        if seed is not None:
            np.random.seed(seed)

        # np.random.multinomial casts pvals to float64 before checking that they sum to <= 1;
        # normalising in float64 avoids spurious failures of that check for float32 scores.
        scores = np.asarray(self.scores, dtype=np.float64)
        total = scores.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError(
                'Acquisition scores must sum to a positive, finite value, got {}.'.format(total))

        self.counts = np.random.multinomial(M, scores / total)

    def _step(self, m, w, **kwargs):
        """
        Adds the data point with the m-th most counts to the batch.
        Because sampling is with replacement, fewer than M distinct points may have been drawn;
        once all of them have been added, further steps leave w unchanged.
        :param m: (int) Batch iteration.
        :param w: (numpy array) Current weight vector (modified in place).
        :param kwargs: (dict) Additional arguments.
        :return: (numpy array) Weight vector after adding m-th data point to the batch.
        """
        # Strict comparison: position m == count_nonzero in the descending sort is the first
        # point with a count of zero, i.e. a point that was never sampled.
        if m < np.count_nonzero(self.counts):
            w[np.argsort(-self.counts)[m]] = 1.

        return w

        
class Argmax(ImportanceSampling):
    """
    Constructs a batch of points by selecting the M highest-scoring points according to the acquisition function.
    """
    def _init_build(self, M=1, seed=None, **kwargs):
        """
        No-op: the selection is deterministic given self.scores, so nothing is sampled up front.
        :param M: (int) Batch size (unused).
        :param seed: (int) Numpy random seed (unused).
        :param kwargs: (dict) Additional arguments (ignored). Accepted because build() forwards
            all of its keyword arguments to this method.
        """
        pass

    def _step(self, m, w, **kwargs):
        """
        Adds the data point with the m-th highest score to the batch.
        :param m: (int) Batch iteration.
        :param w: (numpy array) Current weight vector (modified in place).
        :param kwargs: (dict) Additional arguments.
        :return: (numpy array) Weight vector after adding m-th data point to the batch.
        """
        # Sorting the negated scores ascending yields indices in descending score order.
        w[np.argsort(-self.scores)[m]] = 1.
        return w


class ProjectedFrankWolfe(object):
    def __init__(self, model, data, J, **kwargs):
        """
        Constructs a batch of points using ACS-FW with random projections. Note the slightly different interface.
        :param model: (nn.module) PyTorch model; must provide get_projections(data, J, **kwargs).
        :param data: (ActiveLearningDataset) Dataset.
        :param J: (int) Number of projections.
        :param kwargs: (dict) Additional arguments, forwarded to model.get_projections.
        """
        # ELn: one projection vector per data point (presumably shape (N, J) -- confirm against
        # model.get_projections); entropy: one value per data point, used for diagnostics only.
        self.ELn, self.entropy = model.get_projections(data, J, **kwargs)
        squared_norm = torch.sum(self.ELn * self.ELn, dim=-1)
        # Per-point norms; the small constant keeps the square root (and later divisions) away from zero.
        self.sigmas = torch.sqrt(squared_norm + 1e-6)
        self.sigma = self.sigmas.sum()
        # Sum of all per-point vectors: the target that the weighted batch should approximate.
        self.EL = torch.sum(self.ELn, dim=0)

        # for debugging
        self.model = model
        self.data = data

    def _init_build(self, M, **kwargs):
        pass  # unused

    def build(self, M=1, **kwargs):
        """
        Constructs a batch of points to sample from the unlabeled set.
        Runs at most 10*M Frank-Wolfe steps, stopping as soon as exactly M points have non-zero weight.
        If fewer than M points were selected, the batch is filled up with randomly chosen
        unselected points (each given weight 1). Diagnostics are printed to stdout.
        :param M: (int) Batch size.
        :param kwargs: (dict) Additional parameters.
        :return: (numpy array of ints, numpy array) Selected data point indices and their weights.
        """
        self._init_build(M, **kwargs)
        w = utils.to_gpu(torch.zeros([len(self.ELn), 1]))

        def norm(weights):
            # Distance between the full sum EL and its weighted approximation.
            return (self.EL - (self.ELn.t() @ weights).squeeze()).norm()

        for m in range(10 * M):
            w = self._step(m, w)
            if len(w.nonzero()[:, 0]) == M:
                break

        num_selected = len(w.nonzero()[:, 0])
        print('num of data queried before randum fill', num_selected, '\n')
        if num_selected < M:
            # Top up with uniformly chosen points that currently have zero weight.
            remaining = (w[:, 0] == 0).nonzero().view(1, -1)[0]
            more_idx = torch.randperm(len(remaining))[:M - num_selected]
            more_idx = remaining[more_idx]
            w[more_idx] += 1

        print('|| L-L(w)  ||: {:.4f}'.format(norm(w)))
        print('|| L-L(w1) ||: {:.4f}'.format(norm((w > 0).float())))
        print('Avg pred entropy (pool): {:.4f}'.format(self.entropy.mean().item()))
        print('Avg pred entropy (batch): {:.4f}'.format(self.entropy[w.flatten() > 0].mean().item()))
        try:
            logdet = torch.slogdet(self.model.linear._compute_posterior()[1])[1].item()
            print('logdet weight cov: {:.4f}'.format(logdet))
        except TypeError:
            # Best-effort diagnostic only: skipped when the posterior cannot be computed this way.
            pass

        supp = w.nonzero()[:, 0]
        return supp.cpu().numpy(), w[supp, 0].cpu().numpy()

    def _step(self, m, w, **kwargs):
        """
        Applies one step of the Frank-Wolfe algorithm to update weight vector w.
        Also stores the current weighted sum in self.ELw (read by compute_gamma).
        :param m: (int) Batch iteration (unused).
        :param w: (torch tensor) Current weight vector, shape (N, 1).
        :param kwargs: (dict) Additional arguments.
        :return: (torch tensor) Weight vector after adding m-th data point to the batch.
        :raises ValueError: If the line-search step size is NaN (zero-length search direction).
        """
        self.ELw = (self.ELn.t() @ w).squeeze()
        # Alignment of each normalised per-point vector with the current residual.
        scores = (self.ELn / self.sigmas[:, None]) @ (self.EL - self.ELw)
        f = torch.argmax(scores)
        gamma, f1 = self.compute_gamma(f, w)
        # Checked with torch directly instead of converting the tensor to numpy first.
        if torch.isnan(gamma).any():
            raise ValueError(
                'Frank-Wolfe line search produced a NaN step size (zero-length search direction).')

        w = (1 - gamma) * w + gamma * (self.sigma / self.sigmas[f]) * f1
        return w

    def compute_gamma(self, f, w):
        """
        Computes line-search parameter gamma. Relies on self.ELw having been set by _step.
        :param f: (int) Index of selected data point.
        :param w: (torch tensor) Current weight vector.
        :return: (torch tensor, torch tensor) Line-search parameter gamma and f-th unit vector [0, 0, ..., 1, ..., 0]
        """
        f1 = torch.zeros_like(w)
        f1[f] = 1
        # Rescaled vector of the selected point.
        Lf = (self.sigma / self.sigmas[f] * f1.t() @ self.ELn).squeeze()
        Lfw = Lf - self.ELw
        numerator = Lfw @ (self.EL - self.ELw)
        denominator = Lfw @ Lfw  # zero when Lf equals the current approximation -> gamma is NaN
        return numerator / denominator, f1


class SparseCoreset(object):
    def __init__(self, model, data, J, alpha, beta, optimization_method, zero_mean=True, **kwargs):
        """
        Constructs a batch of points using SABAL with random projections.
        :param model: (nn.module) PyTorch model; must provide get_projections_SABAL(data, J, zero_mean, **kwargs).
        :param data: (ActiveLearningDataset) Dataset.
        :param J: (int) Number of projections.
        :param alpha: (float) scaling factor for the variance term
        :param beta: (float) scaling factor for the weight regularizer
        :param optimization_method: (str) 'greedy' or 'prox_iht' (proximal IHT) for sparse approximation
        :param zero_mean: (bool) whether to zero mean the projection matrix, always true for SABAL
        :param kwargs: (dict) Additional arguments, forwarded to model.get_projections_SABAL.
        """
        self.model = model
        self.data = data
        self.J = J
        self.alpha = alpha
        self.beta = beta
        self.optimization_method = optimization_method
        self.zero_mean = zero_mean

        projections, variance = model.get_projections_SABAL(data, self.J, self.zero_mean, **kwargs)
        # Stored transposed and on the CPU; build() splits its rows into train/validation parts
        # (presumably rows are projections and columns are data points -- confirm against the model).
        self.projection_vector = projections.T.cpu()
        self.variance = torch.squeeze(variance).cpu()

    def build(self, num_queried=1, **kwargs):
        """
        Constructs a batch of points to sample from the unlabeled set.
        The first 90% of the rows of the projection matrix are used for the optimization, the
        remaining rows only to report a validation objective. Diagnostics are printed to stdout.
        The stored projections are left untouched, so build() can be called repeatedly.
        :param num_queried: (int) number of queried data each AL iteration
        :param kwargs: (dict) Additional arguments.
        :return: (indices, weights) Support returned by the optimizer and the weights at those indices.
        :raises ValueError: If optimization_method is neither 'prox_iht' nor 'greedy'.
        """
        k = num_queried
        num_rows, num_cols = self.projection_vector.shape[0], self.projection_vector.shape[1]
        num_train = round(0.9 * num_rows)
        Phi = self.projection_vector[0:num_train]
        Phi_valid = self.projection_vector[num_train:]
        # Targets: row-wise mean of the projections.
        y = Phi.sum(dim=1).reshape([-1, 1]) / num_cols
        y_valid = Phi_valid.sum(dim=1).reshape([-1, 1]) / num_cols
        # Out-of-place division: Phi and Phi_valid are views of self.projection_vector, so an
        # in-place `/=` would rescale the stored projections on every call.
        Phi = Phi / k
        Phi_valid = Phi_valid / k
        # Scale so that both targets have unit norm.
        Phi = Phi / y.norm()
        y = y / y.norm()
        Phi_valid = Phi_valid / y_valid.norm()
        y_valid = y_valid / y_valid.norm()

        # Normalise the variances by the sum of the k largest ones.
        variance_idx = torch.argsort(self.variance, descending=True)[:k]
        variance = self.variance / torch.sum(self.variance[variance_idx])
        print( 'mean varance/largest varaince\n', torch.mean(self.variance)*k/torch.sum(self.variance[variance_idx]) )

        if self.optimization_method == 'prox_iht':
            w, supp = opt.proximal_iht(Phi, y, variance, k, self.alpha, self.beta, verbose=True, reg_type='one')
        elif self.optimization_method == 'greedy':
            # Column norms of the training projections and their sum.
            sigma = Phi.pow(2).sum(dim=0).pow(0.5)
            L = sigma.sum().item()
            w, supp = opt.greedy(Phi, y, variance, k, self.alpha, sigma, L, self.beta, verbose=True, reg_type='one')
        else:
            raise ValueError(
                "Unknown optimization_method {!r}; expected 'prox_iht' or 'greedy'.".format(
                    self.optimization_method))

        # optimization results
        print('optimization results:')
        print('{} items are selected; maximal coreset size k = {}.'.format(len(supp), k))
        weights_display_num = min(len(supp), 15)
        print('weights selected (first {} are displayed) are {} ...'.format(weights_display_num, w[supp[:weights_display_num]].reshape([-1])))
        w_normalize = w.reshape([-1])  # ideally no need to normalize
        print('weights after normalization (first {} are displayed) are {} ...'.format(weights_display_num, w_normalize[supp[:weights_display_num]]))
        w_min = w_normalize[supp].min().item()
        w_max = w_normalize[supp].max().item()
        mean_deviation = (w_normalize[supp] - 1).abs().mean().item()
        print('weights has minimum {}, maximum {}, and mean deviation {}'.format(w_min, w_max, mean_deviation))
        f, f1, f2 = opt.obj(y, Phi, variance, self.alpha, w, supp)
        # The tiny constant guards the division when alpha is zero.
        print('training objective (f1 + alpha * f2) is {}, approximation loss (f1) is {}, selected variance loss '
              '(alpha * f2) is {}, selected original variance (f2) is {}'.format(f, f1, f2, f2 / (self.alpha + 1e-30)))

        f_valid, f1_valid, f2_valid = opt.obj(y_valid, Phi_valid, variance, self.alpha, w, supp)
        print('validation objective (f1 + alpha * f2) is {}, approximation loss (f1) is {}, selected variance loss '
              '(alpha * f2) is {}, selected original variance (f2) is {}'.format(f_valid, f1_valid, f2_valid, f2_valid / (self.alpha + 1e-30)))

        return supp, w_normalize[supp]

